{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.3 PyTorchでMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 手書き数字の画像データMNISTをダウンロード\n",
    "\n",
    "# from sklearn.datasets import fetch_mldata\n",
    "# mnist = fetch_mldata('MNIST original', data_home=\".\")  # data_homeは保存先を指定します\n",
    "\n",
    "# 2019年1月31日訂正\n",
    "# 上記コードでは、以下のエラーが発生します\n",
    "#  [WinError 10060] 接続済みの呼び出し先が一定の時間を過ぎても正しく応答しなかったため、接続できませんでした。または接続済みのホストが応答しなかったため、確立された接続は失敗しました。\n",
    "\n",
    "from sklearn.datasets import fetch_openml\n",
    "mnist = fetch_openml('mnist_784', version=1, data_home=\".\")  # data_homeは保存先を指定します\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. データの前処理（画像データとラベルに分割し、正規化）\n",
    "\n",
    "X = mnist.data / 255  # 0-255を0-1に正規化\n",
    "y = mnist.target\n",
    "\n",
    "# 2019年1月31日訂正\n",
    "# MNISTのデータセットの変更により、ラベルが数値データになっていないので、\n",
    "# 以下により、NumPyの配列の数値型に変換します\n",
    "\n",
    "import numpy as np\n",
    "y = np.array(y)\n",
    "y = y.astype(np.int32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "この画像データのラベルは5です\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAADgpJREFUeJzt3X+MVfWZx/HPs1j+kKI4aQRCYSnEYJW4082IjSWrxkzVDQZHrekkJjQapn8wiU02ZA3/VNNgyCrslmiamaZYSFpKE3VB0iw0otLGZuKIWC0srTFsO3IDNTjywx9kmGf/mEMzxbnfe+fec++5zPN+JeT+eM6558kNnznn3O+592vuLgDx/EPRDQAoBuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxDUZc3cmJlxOSHQYO5u1SxX157fzO40syNm9q6ZPVrPawFoLqv12n4zmybpj5I6JQ1Jel1St7sfSqzDnh9osGbs+ZdJetfd33P3c5J+IWllHa8HoInqCf88SX8Z93goe+7vmFmPmQ2a2WAd2wKQs3o+8Jvo0OJzh/Xu3i+pX+KwH2gl9ez5hyTNH/f4y5KO1dcOgGapJ/yvS7rGzL5iZtMlfVvSrnzaAtBoNR/2u/uImfVK2iNpmqQt7v6H3DoD0FA1D/XVtDHO+YGGa8pFPgAuXYQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EVfMU3ZJkZkclnZZ0XtKIu3fk0RTyM23atGT9yiuvbOj2e3t7y9Yuv/zy5LpLlixJ1tesWZOsP/XUU2Vr3d3dyXU//fTTZH3Dhg3J+uOPP56st4K6wp+5zd0/yOF1ADQRh/1AUPWG3yXtNbM3zKwnj4YANEe9h/3fcPdjZna1pF+b2f+6+/7xC2R/FPjDALSYuvb87n4suz0h6QVJyyZYpt/dO/gwEGgtNYffzGaY2cwL9yV9U9I7eTUGoLHqOeyfLekFM7vwOj939//JpSsADVdz+N39PUn/lGMvU9aCBQuS9enTpyfrN998c7K+fPnysrVZs2Yl173vvvuS9SINDQ0l65s3b07Wu7q6ytZOnz6dXPett95K1l999dVk/VLAUB8QFOEHgiL8QFCEHwiK8ANBEX4gKHP35m3MrHkba6L29vZkfd++fcl6o79W26pGR0eT9YceeihZP3PmTM3bLpVKyfqHH36YrB85cqTmbTeau1s1y7HnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgGOfPQVtbW7I+MDCQrC9atCjPdnJVqffh4eFk/bbbbitbO3fuXHLdqNc/1ItxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QVB6z9IZ38uTJZH3t2rXJ+ooVK5L1N998M1mv9BPWKQcPHkzWOzs7k/WzZ88m69dff33Z2iOPPJJcF43Fnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX4z2yJphaQT7r40e65N0g5JCyUdlfSAu6d/6FxT9/v89briiiuS9UrTSff19ZWtPfzww8l1H3zwwWR9+/btyTpaT57f5/+ppDsveu5RSS+5+zWSXsoeA7iEVAy/u++XdPElbCslbc3ub5V0T859AWiwWs/5Z7t7SZKy26vzawlAMzT82n4z65HU0+jtAJicWvf8x81sriRltyfKLeju/e7e4e4dNW4LQAPUGv5dklZl91dJ2plPOwCapWL4zWy7pN9JWmJmQ2b2sKQNkjrN7E+SOrPHAC4hFc/53b27TOn2nHsJ69SpU3Wt/9FHH9W87urVq5P1HTt2JOujo6M1bxvF4go/ICjCDwRF+IGgCD8QFOEHgiL8QFBM0T0FzJgxo2ztxRdfTK57yy23JOt33XVXsr53795kHc3HFN0Akgg/EBThB4Ii/EBQhB8IivADQRF+ICjG+ae4xYsXJ+sHDhxI1oeHh5P1l19+OVkfHBwsW3vmmWeS6zbz/+ZUwjg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7gurq6kvVnn302WZ85c2bN2163bl2yvm3btmS9VCrVvO2pjHF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBUxXF+M9siaYWkE+6+NHvuMUmrJf01W2ydu/+q4sYY57/kLF26NFnftGlTsn777bXP5N7X15esr1+/Pll///33a972pSzPcf6fSrpzguf/093bs38Vgw+gtVQMv7vvl3SyCb0AaKJ6zvl7zez3ZrbFzK7KrSMATVFr+H8kabGkdkklSRvLLWhmPWY2aGblf8wNQNPVFH53P+7u5919VNKPJS1LLNvv7h3u3lFrkwDyV1P4zWzuuIddkt7Jpx0AzXJZpQXMbLukWyV9ycyGJH1f0q1m1i7JJR2V9N0G9gigAfg+P+oya9asZP3uu+8uW6v0WwFm6eHqffv2JeudnZ3J+lTF9/kBJBF+ICjCDwRF+IGgCD8QFOEHgmKoD4X57LPPkvXLLktfhjIyMpKs33HHHWVrr7zySnLdSxlDfQCSCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIrf50dsN9xwQ7J+//33J+s33nhj2VqlcfxKDh06lKzv37+/rtef6tjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNPcUuWLEnWe3t7k/V77703WZ8zZ86ke6rW+fPnk/VSqZSsj46O5tnOlMOeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2bzJW2TNEfSqKR+d/+hmbVJ2iFpoaSjkh5w9w8b12pclcbSu7u7y9YqjeMvXLiwlpZyMTg4mKyvX78+Wd+1a1ee7YRTzZ5/RNK/uftXJX1d0hozu07So5JecvdrJL2UPQZwiagYfncvufuB7P5pSYclzZO0UtLWbLGtku5pVJMA8jepc34zWyjpa5IGJM1295I09gdC0tV5Nwegcaq+tt/MvijpOUnfc/dTZlVNByYz65HUU1t7ABqlqj2/mX1BY8H/mbs/nz193MzmZvW5kk5MtK6797t7h7t35NEwgHxUDL+N7eJ/Iumwu28aV9olaVV2f5Wknfm3B6BRKk7RbWbLJf1G0tsaG+qTpHUaO+//paQFkv4s6VvufrLCa4Wconv27NnJ+nXXXZesP/3008n6tddeO+me8jIwMJCsP/nkk2VrO3em9xd8Jbc21U7RXfGc391/K6nci90+maYAtA6u8AOCIvxAUIQfCIrwA0ERfiAowg8ExU93V6mtra1sra+vL7lue3t7sr5o0aKaesrDa6+9lqxv3LgxWd+zZ0+y/sknn0y6JzQHe34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCCrMOP9NN92UrK9duzZZX7ZsWdnavHnzauopLx9//HHZ2ubNm5PrPvHEE8n62bNna+oJrY89PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EFWacv6urq656PQ4dOpSs7969O1kfGRlJ1lPfuR8eHk6ui7jY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUObu6QXM5kvaJmmOpFFJ/e7+QzN7TNJqSX/NFl3n7r+q8FrpjQGom7tbNctVE/65kua6+wEzmynpDUn3SHpA0hl3f6rapgg/0HjVhr/iFX7uXpJUyu6fNrPDkor96RoAdZvUOb+ZLZT0NUkD2VO9ZvZ7M9tiZleVWafHzAbNbLCuTgHkquJh/98WNPuipFclrXf3581stqQPJLmkH2js1OChCq/BYT/QYLmd80uSmX1B0m5Je9x90wT1hZJ2u/vSCq9D+IEGqzb8FQ/7zcwk/UTS4fHBzz4IvKBL0juTbRJAcar5tH+5pN9IeltjQ32StE5St6R2jR32H5X03ezDwdRrsecHGizXw/68EH6g8XI77AcwNRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCavYU3R9I+r9xj7+UPdeKWrW3Vu1Lorda5dnbP1a7YFO/z/+5jZsNuntHYQ0ktGpvrdqXRG+1Kqo3DvuBoAg/EFTR4e8vePsprdpbq/Yl0VutCumt0HN+AMUpes8PoCCFhN/M7jSzI2b2rpk9WkQP5ZjZUTN728wOFj3FWDYN2gkze2fcc21m9msz+1N2O+E0aQX19piZvZ+9dwfN7F8L6m2+mb1sZofN7A9m9kj2fKHvXaKvQt63ph/2m9k0SX+U1ClpSNLrkrrd/VBTGynDzI5K6nD3wseEzexfJJ2RtO3CbEhm9h+STrr7huwP51Xu/u8t0ttjmuTMzQ3qrdzM0t9Rge9dnjNe56GIPf8ySe+6+3vufk7SLyStLKCPlufu+yWdvOjplZK2Zve3auw/T9OV6a0luHvJ3Q9k909LujCzdKHvXaKvQhQR/nmS/jLu8ZBaa8pvl7TXzN4ws56im5nA7AszI2W3Vxfcz8UqztzcTBfNLN0y710tM17nrYjwTzSbSCsNOXzD3f9Z0l2S1mSHt6jOjyQt1tg0biVJG4tsJptZ+jlJ33P3U0X2Mt4EfRXyvhUR/iFJ88c9/rKkYwX0MSF3P5bdnpD0gsZOU1rJ8QuTpGa3Jwru52/c/bi7n3f3UUk/VoHvXTaz9HOSfubuz2dPF/7eTdRXUe9bEeF/XdI1ZvYVM5su6duSdhXQx+eY2YzsgxiZ2QxJ31TrzT68S9Kq7P4qSTsL7OXvtMrMzeVmllbB712rzXhdyEU+2VDGf0maJmmLu69vehMTMLNFGtvbS2PfePx5kb2Z2XZJt2rsW1/HJX1f0n9L+qWkBZL+LOlb7t70D97K9HarJjlzc4N6Kzez9IAKfO/ynPE6l364wg+IiSv8gKAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E9f/Ex0YKZYOZcwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# MNISTのデータの1つ目を可視化する\n",
    "\n",
    "# 2019年1月31日訂正\n",
    "# %matplotlibで%のあとに入っていたスペースを削除\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "plt.imshow(X[0].reshape(28, 28), cmap='gray')\n",
    "print(\"この画像データのラベルは{:.0f}です\".format(y[0]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. DataLoderの作成\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# 2.1 データを訓練とテストに分割（6:1）\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=1/7, random_state=0)\n",
    "\n",
    "# 2.2 データをPyTorchのTensorに変換\n",
    "X_train = torch.Tensor(X_train)\n",
    "X_test = torch.Tensor(X_test)\n",
    "y_train = torch.LongTensor(y_train)\n",
    "y_test = torch.LongTensor(y_test)\n",
    "\n",
    "# 2.3 データとラベルをセットにしたDatasetを作成\n",
    "ds_train = TensorDataset(X_train, y_train)\n",
    "ds_test = TensorDataset(X_test, y_test)\n",
    "\n",
    "# 2.4 データセットのミニバッチサイズを指定した、Dataloaderを作成\n",
    "# Chainerのiterators.SerialIteratorと似ている\n",
    "loader_train = DataLoader(ds_train, batch_size=64, shuffle=True)\n",
    "loader_test = DataLoader(ds_test, batch_size=64, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (fc1): Linear(in_features=784, out_features=100, bias=True)\n",
      "  (relu1): ReLU()\n",
      "  (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (relu2): ReLU()\n",
      "  (fc3): Linear(in_features=100, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# 3. ネットワークの構築\n",
    "# Keras風の書き方 \n",
    "\n",
    "from torch import nn\n",
    "\n",
    "model = nn.Sequential()\n",
    "model.add_module('fc1', nn.Linear(28*28*1, 100))\n",
    "model.add_module('relu1', nn.ReLU())\n",
    "model.add_module('fc2', nn.Linear(100, 100))\n",
    "model.add_module('relu2', nn.ReLU())\n",
    "model.add_module('fc3', nn.Linear(100, 10))\n",
    "\n",
    "print(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. 誤差関数と最適化手法の設定\n",
    "\n",
    "from torch import optim\n",
    "\n",
    "# 誤差関数の設定\n",
    "loss_fn = nn.CrossEntropyLoss()  # 変数名にはcriterionが使われることも多い\n",
    "\n",
    "# 重みを学習する際の最適化手法の選択\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.01)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. 学習と推論の設定\n",
    "# 5-1. 学習1回でやることを定義します\n",
    "# Chainerのtraining.Trainer()に対応するものはない\n",
    "\n",
    "\n",
    "def train(epoch):\n",
    "    model.train()  # ネットワークを学習モードに切り替える\n",
    "\n",
    "    # データローダーから1ミニバッチずつ取り出して計算する\n",
    "    for data, targets in loader_train:\n",
    "      \n",
    "        optimizer.zero_grad()  # 一度計算された勾配結果を0にリセット\n",
    "        outputs = model(data)  # 入力dataをinputし、出力を求める\n",
    "        loss = loss_fn(outputs, targets)  # 出力と訓練データの正解との誤差を求める\n",
    "        loss.backward()  # 誤差のバックプロパゲーションを求める\n",
    "        optimizer.step()  # バックプロパゲーションの値で重みを更新する\n",
    "\n",
    "    print(\"epoch{}：終了\\n\".format(epoch))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. 学習と推論の設定\n",
    "# 5-2. 推論1回でやることを定義します\n",
    "# Chainerのtrainer.extend(extensions.Evaluator())に対応するものはない\n",
    "\n",
    "\n",
    "def test():\n",
    "    model.eval()  # ネットワークを推論モードに切り替える\n",
    "    correct = 0\n",
    "\n",
    "    # データローダーから1ミニバッチずつ取り出して計算する\n",
    "    with torch.no_grad():  # 微分は推論では必要ない\n",
    "        for data, targets in loader_test:\n",
    "\n",
    "            outputs = model(data)  # 入力dataをinputし、出力を求める\n",
    "\n",
    "            # 推論する\n",
    "            _, predicted = torch.max(outputs.data, 1)  # 確率が最大のラベルを求める\n",
    "            correct += predicted.eq(targets.data.view_as(predicted)).sum()  # 正解と一緒だったらカウントアップ\n",
    "\n",
    "    # 正解率を出力\n",
    "    data_num = len(loader_test.dataset)  # データの総数\n",
    "    print('\\nテストデータの正解率: {}/{} ({:.0f}%)\\n'.format(correct,\n",
    "                                                   data_num, 100. * correct / data_num))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "テストデータの正解率: 1191/10000 (11%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 学習なしにテストデータで推論してみよう\n",
    "test()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch0：終了\n",
      "\n",
      "epoch1：終了\n",
      "\n",
      "epoch2：終了\n",
      "\n",
      "\n",
      "テストデータの正解率: 9572/10000 (95%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 6. 学習と推論の実行\n",
    "for epoch in range(3):\n",
    "    train(epoch)\n",
    "\n",
    "test()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "予測結果は2\n",
      "この画像データの正解ラベルは2です\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAADlBJREFUeJzt3X2IVOcVx/HfSaIEtkWyGK1sNbZGmlT/2IZFAi3VJlhsKNESmigYDGm6QhrSQkMSzB8NNEJjao2QUFitaKDV1tWohBJb89qQYrIGqan2JRFb1xXXYEPTEDTG0z/2btnq3mfGmTtzxz3fD4R5OXPvPdz42zszz73zmLsLQDyXld0AgHIQfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQV3RzI2ZGacTAg3m7lbN6+o68pvZAjP7q5m9Y2YP17MuAM1ltZ7bb2aXS/qbpPmS+iW9KWmJux9MLMORH2iwZhz550h6x90Pu/sZSVskLaxjfQCaqJ7wd0g6OuJxf/bc/zGzbjPrM7O+OrYFoGD1fOE32luLC97Wu3uPpB6Jt/1AK6nnyN8vaeqIx5+VNFBfOwCapZ7wvylpppl9zszGS1osaVcxbQFotJrf9rv7WTO7T9JuSZdL2uDufy6sMwANVfNQX00b4zM/0HBNOckHwKWL8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaBqnqJbkszsiKQPJH0i6ay7dxXRFFrHrFmzkvVFixYl67feemturaurvn8ur732WrL+4IMP5tb27t1b17bHgrrCn/mau79XwHoANBFv+4Gg6g2/S/qdme0zs+4iGgLQHPW+7f+yuw+Y2SRJvzezv7j7qyNfkP1R4A8D0GLqOvK7+0B2OyjpWUlzRnlNj7t38WUg0FpqDr+ZtZnZp4fvS/q6pLeLagxAY9Xztn+ypGfNbHg9v3L35wvpCkDDmbs3b2NmzdtYIKmx+Pnz5yeXTY3DS9LcuXOT9Wb++zlfduDJNTg4mFu7/vrrk8u+//77NfXUCtw9vWMyDPUBQRF+ICjCDwRF+IGgCD8QFOEHgiriqj402F133ZWsr1q1KrfW3t5ecDfFOXToULK+devWZP2WW25J1lOXDHd3p884T+3TsYIjPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTh/C2hra0vW77///mS9zLH8kydPJuubNm3KrT311FPJZfv7+5P1zs7OZD3lyiuvrHnZsYIjPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTh/Czh79myyfubMmSZ1cqElS5Yk66+//nqyXmmsvh4LFy5M1lM/K37gwIGi27nkcOQHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAqjvOb2QZJ35Q06O6zs+faJf1a0nRJRyTd7u7/alybY9vp06eT9RtvvDFZnz17dm7tjjvuSC67Zs2aZP3UqVPJej0q/Y7BQw89lKxfdln62LVv377c2vPPP59cNoJqjvwbJS0477mHJb3g7jMlvZA9BnAJqRh+d39V0vl//hdKGv6Jlk2SFhXcF4AGq/Uz/2R3Py5J2e2k4loC0AwNP7ffzLolpSdGA9B0tR75T5jZFEnKbgfzXujuPe7e5e75syYCaLpaw79L0rLs/jJJO4tpB0CzVAy/mW2W9EdJXzCzfjP7jqSfSJpvZn+XND97DOASYqlrngvfmFnzNgZJUkdHR7J+7NixJnVyoXnz5iXre/bsSdbNLFlfunRpbm3z5s3JZS9l7p7eMRnO8AOCIvxAUIQfCIrwA0ERfiAowg8ExU93j3FlDuVJ0sSJE3Nrq1atqmvd69evT9Z7e3vrWv9Yx5EfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Likl7UpbOzM1nv6enJrd1www3JZQcGBpL1adOmJetRcUkvgCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiK6/mR1N7enqxv2bIlWb/22mtza5XG8RcsOH9yaBSJIz8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBFVxnN/MNkj6pqRBd5+dPfeopO9KOpm9bIW7/7ZRTaJxKo3jv/zyy8n6zJkzk/WTJ0/m1u65557ksgcPHkzWUZ9qjvwbJY12tsUad+/M/iP4wCWmYvjd/VVJp5rQC4Amqucz/31m9icz22BmVxXWEYCmqDX8P5c0Q1KnpOOSVue90My6zazPzPpq3BaABqgp/O5+wt0/cfdzktZJmpN4bY+7d7l7V61NAiheTeE3sykjHn5L0tvFtAOgWaoZ6tssaZ6kiWbWL+lHkuaZWackl3RE0vIG9gigAfjd/jFu0qRJyfrOnTuT9Tlzcj/RSZKOHj2arD/wwAO5td7e3uSyqA2/2w8gifADQRF+ICjCDwRF+IGgCD8QFEN9BZgwYUKyvmzZsmT9kUceSdbr+X80bty4ZL1S72bpUaPbbrstWd+xY0eyjuIx1AcgifADQRF+ICjCDwRF+IGgCD8QFOEHgmKcv0rXXXddbm337t3JZTs6OpL1vr70L5x1dZX3I0iVxvkrXdL79NNP59Y2btyYXDb1s9/Ixzg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7MokWLkvU1a9bk1vbs2VPzspK0ePHiZH3FihXJesrAwECyvnLlymT93nvvTdZnzZp10T0NO3bsWLK+bt26ZP2xxx6redtjGeP8AJIIPxAU4QeCIvxAUIQfCIrwA0ERfiCoiuP8ZjZV0jOSPiPpnKQed19rZu2Sfi1puqQjkm53939VWFfLjvO/9NJLyXrq2vLVq1cnl3388ceT9blz5ybr586dS9bXr1+fW1u+fHly2XqlpuCuVL/66qvr2vbhw4eT9c7Oztzahx9+WNe2W1mR4/xnJf3Q3a+XdKOk75nZFyU9LOkFd58p6YXsMYBLRMXwu/txd38ru/+BpEOSOiQtlLQpe9kmSelT5AC0lIv6zG9m0yV9SdJeSZPd/bg09AdC0qSimwPQOFdU+0Iz+5SkbZJ+4O7/rvTbbiOW65bUXVt7ABqlqiO/mY3TUPB/6e7bs6dPmNmUrD5F0uBoy7p7j7t3uXt5v0IJ4AIVw29Dh/hfSDrk7j8bUdolaXj62WWSdhbfHoBGqWao7yuS/iDpgIaG+iRphYY+9/9G0jRJ/5T0bXc/VWFdLTvU9+KLLybr11xzTW6tra0tuezEiROT9f379yfrlYYSe3t7c2sff/xxctlGmz59em6t0qXKd999d7Je6aPntm3bcmt33nlnctnTp08n662s2qG+ip/53f01SXkru/limgLQOjjDDwiK8ANBEX4gKMIPBEX4gaAIPxAUP92d2b59e7J+00035dbefffd5LI7d6bPf3riiSeS9Y8++ihZv1SNHz8+Wa90OfKTTz6ZrKf+bVeaVn3r1q3JeqXpxcvET3cDSCL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY56/SjBkzcmuVxvnRGGvXrk3Wly5dmlubMGFCctlXXnklWb/55ta9mp1xfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8GLNSY/E7duxILvvGG2/UvO6yMc4PIInwA0ERfiAowg8ERfiBoAg/EBThB4KqOM5vZlMlPSPpM5LOSepx97Vm9qik70o6mb10hbv/tsK6GOcHGqzacf5qwj9F0hR3f8vMPi1pn6RFkm6X9B93/2m1TRF+oPGqDf8VVazouKTj2f0PzOyQpI762gNQtov6zG9m0yV9SdLe7Kn7zOxPZrbBzK7KWabbzPrMrK+uTgEUqupz+83sU5JekbTS3beb2WRJ70lyST/W0EeDuyusg7f9QIMV9plfksxsnKTnJO1295+NUp8u6Tl3n11hPYQfaLDCLuwxM5P0C0mHRgY/+yJw2LckvX2xTQIoTzXf9n9F0h8kHdDQUJ8krZC0RFKnht72H5G0PPtyMLUujvxAgxX6tr8ohB9oPK7nB5BE+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCKriD3gW7D1J/xjxeGL2XCtq1d5atS+J3mpVZG/XVPvCpl7Pf8HGzfrcvau0BhJatbdW7Uuit1qV1Rtv+4GgCD8QVNnh7yl5+ymt2lur9iXRW61K6a3Uz/wAylP2kR9ASUoJv5ktMLO/mtk7ZvZwGT3kMbMjZnbAzPaXPcVYNg3aoJm9PeK5djP7vZn9PbsddZq0knp71MyOZftuv5ndUlJvU83sJTM7ZGZ/NrPvZ8+Xuu8SfZWy35r+tt/MLpf0N0nzJfVLelPSEnc/2NRGcpjZEUld7l76mLCZfVXSfyQ9MzwbkpmtknTK3X+S/eG8yt0fapHeHtVFztzcoN7yZpa+SyXuuyJnvC5CGUf+OZLecffD7n5G0hZJC0voo+W5+6uSTp339EJJm7L7mzT0j6fpcnprCe5+3N3fyu5/IGl4ZulS912ir1KUEf4OSUdHPO5Xa0357ZJ+Z2b7zKy77GZGMXl4ZqTsdlLJ/Zyv4szNzXTezNIts+9qmfG6aGWEf7TZRFppyOHL7n6DpG9I+l729hbV+bmkGRqaxu24pNVlNpPNLL1N0g/c/d9l9jLSKH2Vst/KCH+/pKkjHn9W0kAJfYzK3Qey20FJz2roY0orOTE8SWp2O1hyP//j7ifc/RN3PydpnUrcd9nM0tsk/dLdt2dPl77vRuurrP1WRvjflDTTzD5nZuMlLZa0q4Q+LmBmbdkXMTKzNklfV+vNPrxL0rLs/jJJO0vs5f+0yszNeTNLq+R912ozXpdykk82lPGkpMslbXD3lU1vYhRm9nkNHe2loSsef1Vmb2a2WdI8DV31dULSjyTtkPQbSdMk/VPSt9296V+85fQ2Txc5c3ODesubWXqvStx3Rc54XUg/nOEHxMQZfkBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgvov7/doAPsQ7IQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 例えば2018番目の画像データを推論してみる\n",
    "\n",
    "index = 2018\n",
    "\n",
    "model.eval()  # ネットワークを推論モードに切り替える\n",
    "data = X_test[index]\n",
    "output = model(data)  # 入力dataをinputし、出力を求める\n",
    "_, predicted = torch.max(output.data, 0)  # 確率が最大のラベルを求める\n",
    "\n",
    "print(\"予測結果は{}\".format(predicted))\n",
    "\n",
    "X_test_show = (X_test[index]).numpy()\n",
    "plt.imshow(X_test_show.reshape(28, 28), cmap='gray')\n",
    "print(\"この画像データの正解ラベルは{:.0f}です\".format(y_test[index]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "#-----------------------------------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (fc1): Linear(in_features=784, out_features=100, bias=True)\n",
      "  (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "  (fc3): Linear(in_features=100, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# 3. ネットワークの構築\n",
    "# ニューラルネットワークの設定（Chainer風の書き方）\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "\n",
    "    def __init__(self, n_in, n_mid, n_out):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(n_in, n_mid)  # Chainerと異なり、Noneは受けつけない\n",
    "        self.fc2 = nn.Linear(n_mid, n_mid)\n",
    "        self.fc3 = nn.Linear(n_mid, n_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 入力xに合わせてforwardの計算を変えられる\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        h2 = F.relu(self.fc2(h1))\n",
    "        output = self.fc3(h2)\n",
    "        return output\n",
    "\n",
    "\n",
    "model = Net(n_in=28*28*1, n_mid=100, n_out=10)  # ネットワークのオブジェクトを生成\n",
    "print(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
